{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Character based RNN language model trained on 'The Complete Works of William Shakespeare'\n",
    "Based on http://karpathy.github.io/2015/05/21/rnn-effectiveness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "using Pkg; for p in (\"Knet\",\"AutoGrad\"); haskey(Pkg.installed(),p) || Pkg.add(p); end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "RNNTYPE = :lstm\n",
    "BATCHSIZE = 256\n",
    "SEQLENGTH = 100\n",
    "INPUTSIZE = 168\n",
    "VOCABSIZE = 84\n",
    "HIDDENSIZE = 334\n",
    "NUMLAYERS = 1\n",
    "DROPOUT = 0.0\n",
    "LR=0.001\n",
    "BETA_1=0.9\n",
    "BETA_2=0.999\n",
    "EPS=1e-08\n",
    "EPOCHS = 30;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpu() = 0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(\"4934845-element Array{UInt8,1}\", \"526731-element Array{UInt8,1}\", \"84-element Array{Char,1}\")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load 'The Complete Works of William Shakespeare'\n",
    "using Knet; @show gpu()\n",
    "include(Knet.dir(\"data\",\"gutenberg.jl\"))\n",
    "trn,tst,chars = shakespeare()\n",
    "map(summary,(trn,tst,chars))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r\n",
      "    Cheated of feature by dissembling nature,\r\n",
      "    Deform'd, unfinish'd, sent before my time\r\n",
      "    Into this breathing world scarce half made up,\r\n",
      "    And that so lamely and unfashionable\r\n",
      " \n"
     ]
    }
   ],
   "source": [
    "# Print a sample\n",
    "println(string(chars[trn[1020:1210]]...)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(192, 20)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Minibatch data\n",
    "function mb(a)\n",
    "    N = div(length(a),BATCHSIZE)\n",
    "    x = reshape(a[1:N*BATCHSIZE],N,BATCHSIZE)' # reshape full data to (B,N) with contiguous rows\n",
    "    minibatch(x[:,1:N-1], x[:,2:N], SEQLENGTH) # split into (B,T) blocks \n",
    "end\n",
    "dtrn,dtst = mb(trn),mb(tst)\n",
    "map(length, (dtrn,dtst))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define model\n",
    "function initmodel()\n",
    "    w(d...)=KnetArray(xavier(Float32,d...))\n",
    "    b(d...)=KnetArray(zeros(Float32,d...))\n",
    "    r,wr = rnninit(INPUTSIZE,HIDDENSIZE,rnnType=RNNTYPE,numLayers=NUMLAYERS,dropout=DROPOUT)\n",
    "    wx = w(INPUTSIZE,VOCABSIZE)\n",
    "    wy = w(VOCABSIZE,HIDDENSIZE)\n",
    "    by = b(VOCABSIZE,1)\n",
    "    return r,wr,wx,wy,by\n",
    "end;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define loss and its gradient\n",
    "function predict(ws,xs,hx,cx;pdrop=0)\n",
    "    r,wr,wx,wy,by = ws\n",
    "    x = wx[:,xs]                                    # xs=(B,T) x=(X,B,T)\n",
    "    x = dropout(x,pdrop)\n",
    "    y,hy,cy = rnnforw(r,wr,x,hx,cx,hy=true,cy=true) # y=(H,B,T) hy=cy=(H,B,L)\n",
    "    y = dropout(y,pdrop)\n",
    "    y2 = reshape(y,size(y,1),size(y,2)*size(y,3))   # y2=(H,B*T)\n",
    "    return wy*y2.+by, hy, cy\n",
    "end\n",
    "\n",
    "function loss(w,x,y,h;o...)\n",
    "    py,hy,cy = predict(w,x,h...;o...)\n",
    "    h[1],h[2] = value(hy),value(cy)\n",
    "    return nll(py,y)\n",
    "end\n",
    "using AutoGrad: gradloss\n",
    "lossgradient = gradloss(loss);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train and test loops\n",
    "function train(model,data,optim)\n",
    "    hiddens = Any[nothing,nothing]\n",
    "    Σ,N=0,0\n",
    "    for (x,y) in data\n",
    "        grads,loss1 = lossgradient(model,x,y,hiddens;pdrop=DROPOUT)\n",
    "        update!(model, grads, optim)\n",
    "        Σ,N=Σ+loss1,N+1\n",
    "    end\n",
    "    return Σ/N\n",
    "end\n",
    "\n",
    "function test(model,data)\n",
    "    hiddens = Any[nothing,nothing]\n",
    "    Σ,N=0,0\n",
    "    for (x,y) in data\n",
    "        Σ,N = Σ+loss(model,x,y,hiddens), N+1\n",
    "    end\n",
    "    return Σ/N\n",
    "end; "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Training...\n",
      "└ @ Main In[9]:7\n",
      "┌ Warning: `getval` is deprecated, use `value` instead.\n",
      "│   caller = #loss#8(::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:pdrop,),Tuple{Float64}}}, ::Function, ::Param{Tuple{RNN,KnetArray{Float32,3},KnetArray{Float32,2},KnetArray{Float32,2},KnetArray{Float32,2}}}, ::Array{UInt8,2}, ::Array{UInt8,2}, ::Array{Any,1}) at In[7]:14\n",
      "└ @ Main ./In[7]:14\n",
      "┌ Warning: `getval` is deprecated, use `value` instead.\n",
      "│   caller = #loss#8(::Base.Iterators.Pairs{Symbol,Float64,Tuple{Symbol},NamedTuple{(:pdrop,),Tuple{Float64}}}, ::Function, ::Param{Tuple{RNN,KnetArray{Float32,3},KnetArray{Float32,2},KnetArray{Float32,2},KnetArray{Float32,2}}}, ::Array{UInt8,2}, ::Array{UInt8,2}, ::Array{Any,1}) at In[7]:14\n",
      "└ @ Main ./In[7]:14\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 24.699409 seconds (13.49 M allocations: 856.929 MiB, 1.45% gc time)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Warning: `getval` is deprecated, use `value` instead.\n",
      "│   caller = #loss#8(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::Tuple{RNN,KnetArray{Float32,3},KnetArray{Float32,2},KnetArray{Float32,2},KnetArray{Float32,2}}, ::Array{UInt8,2}, ::Array{UInt8,2}, ::Array{Any,1}) at In[7]:14\n",
      "└ @ Main ./In[7]:14\n",
      "┌ Warning: `getval` is deprecated, use `value` instead.\n",
      "│   caller = #loss#8(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Function, ::Tuple{RNN,KnetArray{Float32,3},KnetArray{Float32,2},KnetArray{Float32,2},KnetArray{Float32,2}}, ::Array{UInt8,2}, ::Array{UInt8,2}, ::Array{Any,1}) at In[7]:14\n",
      "└ @ Main ./In[7]:14\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  0.964949 seconds (1.39 M allocations: 83.647 MiB, 2.74% gc time)\n",
      "(:epoch, 1, :trnppl, 14.17144f0, :tstppl, 8.031225f0)\n",
      " 18.063935 seconds (275.04 k allocations: 172.994 MiB, 0.15% gc time)\n",
      "  0.575244 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 2, :trnppl, 6.8886585f0, :tstppl, 6.1861835f0)\n",
      " 18.169467 seconds (275.45 k allocations: 172.999 MiB, 0.10% gc time)\n",
      "  0.582032 seconds (10.64 k allocations: 13.383 MiB, 2.66% gc time)\n",
      "(:epoch, 3, :trnppl, 5.634745f0, :tstppl, 5.3615026f0)\n",
      " 18.233755 seconds (277.01 k allocations: 173.022 MiB, 0.12% gc time)\n",
      "  0.583114 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 4, :trnppl, 4.9907575f0, :tstppl, 4.882069f0)\n",
      " 18.256127 seconds (276.47 k allocations: 173.014 MiB, 0.08% gc time)\n",
      "  0.581267 seconds (10.64 k allocations: 13.383 MiB, 0.31% gc time)\n",
      "(:epoch, 5, :trnppl, 4.5789294f0, :tstppl, 4.5613937f0)\n",
      " 18.254528 seconds (277.01 k allocations: 173.022 MiB, 0.13% gc time)\n",
      "  0.586195 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 6, :trnppl, 4.288923f0, :tstppl, 4.3301454f0)\n",
      " 18.321975 seconds (276.47 k allocations: 173.014 MiB, 0.23% gc time)\n",
      "  0.582643 seconds (10.64 k allocations: 13.383 MiB, 0.57% gc time)\n",
      "(:epoch, 7, :trnppl, 4.07812f0, :tstppl, 4.1592717f0)\n",
      " 18.299795 seconds (277.01 k allocations: 173.022 MiB, 0.15% gc time)\n",
      "  0.582429 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 8, :trnppl, 3.9208643f0, :tstppl, 4.044902f0)\n",
      " 18.308532 seconds (276.47 k allocations: 173.014 MiB, 0.18% gc time)\n",
      "  0.576081 seconds (10.64 k allocations: 13.383 MiB, 0.51% gc time)\n",
      "(:epoch, 9, :trnppl, 3.8051097f0, :tstppl, 3.9437194f0)\n",
      " 18.360418 seconds (277.01 k allocations: 173.022 MiB, 0.50% gc time)\n",
      "  0.579799 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 10, :trnppl, 3.702473f0, :tstppl, 3.8641992f0)\n",
      " 18.286052 seconds (276.47 k allocations: 173.015 MiB, 0.30% gc time)\n",
      "  0.580104 seconds (10.64 k allocations: 13.383 MiB, 1.16% gc time)\n",
      "(:epoch, 11, :trnppl, 3.6221309f0, :tstppl, 3.8023968f0)\n",
      " 18.295108 seconds (277.01 k allocations: 173.022 MiB, 0.14% gc time)\n",
      "  0.579896 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 12, :trnppl, 3.5546553f0, :tstppl, 3.743243f0)\n",
      " 18.306134 seconds (276.47 k allocations: 173.014 MiB, 0.12% gc time)\n",
      "  0.583793 seconds (10.64 k allocations: 13.383 MiB, 1.13% gc time)\n",
      "(:epoch, 13, :trnppl, 3.4973187f0, :tstppl, 3.6886632f0)\n",
      " 18.311764 seconds (277.01 k allocations: 173.022 MiB, 0.09% gc time)\n",
      "  0.580448 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 14, :trnppl, 3.4478593f0, :tstppl, 3.64481f0)\n",
      " 18.280997 seconds (276.47 k allocations: 173.015 MiB, 0.10% gc time)\n",
      "  0.575429 seconds (10.64 k allocations: 13.383 MiB, 0.29% gc time)\n",
      "(:epoch, 15, :trnppl, 3.403747f0, :tstppl, 3.6052387f0)\n",
      " 18.306076 seconds (277.01 k allocations: 173.022 MiB, 0.12% gc time)\n",
      "  0.583030 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 16, :trnppl, 3.365376f0, :tstppl, 3.574168f0)\n",
      " 18.310505 seconds (276.47 k allocations: 173.014 MiB, 0.09% gc time)\n",
      "  0.581525 seconds (10.64 k allocations: 13.383 MiB, 0.46% gc time)\n",
      "(:epoch, 17, :trnppl, 3.3322759f0, :tstppl, 3.544019f0)\n",
      " 18.310224 seconds (277.01 k allocations: 173.022 MiB, 0.10% gc time)\n",
      "  0.579436 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 18, :trnppl, 3.3026357f0, :tstppl, 3.5193627f0)\n",
      " 18.314848 seconds (276.47 k allocations: 173.014 MiB, 0.21% gc time)\n",
      "  0.579691 seconds (10.64 k allocations: 13.383 MiB, 0.27% gc time)\n",
      "(:epoch, 19, :trnppl, 3.2748566f0, :tstppl, 3.4941688f0)\n",
      " 18.341136 seconds (277.01 k allocations: 173.022 MiB, 0.07% gc time)\n",
      "  0.587400 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 20, :trnppl, 3.2502308f0, :tstppl, 3.4708767f0)\n",
      " 18.300109 seconds (276.47 k allocations: 173.014 MiB, 0.09% gc time)\n",
      "  0.578100 seconds (10.64 k allocations: 13.383 MiB, 0.27% gc time)\n",
      "(:epoch, 21, :trnppl, 3.227919f0, :tstppl, 3.4508634f0)\n",
      " 18.283549 seconds (277.01 k allocations: 173.023 MiB, 0.10% gc time)\n",
      "  0.577311 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 22, :trnppl, 3.2075968f0, :tstppl, 3.4330473f0)\n",
      " 18.297627 seconds (276.47 k allocations: 173.014 MiB, 0.11% gc time)\n",
      "  0.580631 seconds (10.64 k allocations: 13.383 MiB, 0.27% gc time)\n",
      "(:epoch, 23, :trnppl, 3.188219f0, :tstppl, 3.4195304f0)\n",
      " 18.312334 seconds (277.01 k allocations: 173.022 MiB, 0.09% gc time)\n",
      "  0.581273 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 24, :trnppl, 3.1699858f0, :tstppl, 3.4076266f0)\n",
      " 18.304316 seconds (276.47 k allocations: 173.014 MiB, 0.17% gc time)\n",
      "  0.583314 seconds (10.64 k allocations: 13.383 MiB, 0.39% gc time)\n",
      "(:epoch, 25, :trnppl, 3.1535237f0, :tstppl, 3.397515f0)\n",
      " 18.264337 seconds (277.01 k allocations: 173.022 MiB, 0.18% gc time)\n",
      "  0.581945 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 26, :trnppl, 3.1381412f0, :tstppl, 3.3878815f0)\n",
      " 18.285340 seconds (276.47 k allocations: 173.014 MiB, 0.14% gc time)\n",
      "  0.585088 seconds (10.64 k allocations: 13.383 MiB, 0.44% gc time)\n",
      "(:epoch, 27, :trnppl, 3.1238694f0, :tstppl, 3.3777041f0)\n",
      " 18.281101 seconds (277.01 k allocations: 173.022 MiB, 0.08% gc time)\n",
      "  0.580616 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 28, :trnppl, 3.1108415f0, :tstppl, 3.3685112f0)\n",
      " 18.269055 seconds (276.47 k allocations: 173.015 MiB, 0.11% gc time)\n",
      "  0.575343 seconds (10.64 k allocations: 13.383 MiB, 0.48% gc time)\n",
      "(:epoch, 29, :trnppl, 3.0995467f0, :tstppl, 3.3611314f0)\n",
      " 18.319557 seconds (277.01 k allocations: 173.022 MiB, 0.10% gc time)\n",
      "  0.583192 seconds (7.92 k allocations: 13.342 MiB)\n",
      "(:epoch, 30, :trnppl, 3.0894253f0, :tstppl, 3.3524172f0)\n",
      "573.280564 seconds (24.33 M allocations: 6.253 GiB, 0.21% gc time)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "\"Tuple{RNN,KnetArray{Float32,3},KnetArray{Float32,2},KnetArray{Float32,2},KnetArray{Float32,2}}\""
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train model or load from file if exists\n",
    "model=optim=nothing; \n",
    "Knet.gc()\n",
    "if !isfile(\"shakespeare.jld2\")\n",
    "    model = initmodel()\n",
    "    optim = optimizers(model, Adam; lr=LR, beta1=BETA_1, beta2=BETA_2, eps=EPS);    \n",
    "    @info(\"Training...\")\n",
    "    @time for epoch in 1:EPOCHS\n",
    "        @time trnloss = train(model,dtrn,optim) # ~18 seconds\n",
    "        @time tstloss = test(model,dtst)        # ~0.5 seconds\n",
    "        println((:epoch, epoch, :trnppl, exp(trnloss), :tstppl, exp(tstloss)))\n",
    "    end\n",
    "    Knet.save(\"shakespeare.jld2\",\"model\",model)\n",
    "else\n",
    "    model = Knet.load(\"shakespeare.jld2\",\"model\")\n",
    "end\n",
    "summary(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "                         Enter a Seins of\n",
      "                       No. Marcorius Marcious Comanio, there is nothing of your England.\n",
      "\n",
      "  Nurse. Gods fall, devouring.\n",
      "\n",
      "         Enter angar\n",
      "\n",
      "                         Enter MISTRESS QUINDER and to LAVINIA\n",
      "\n",
      "  BOLINBBRO. I never take a messenger (look-like musicon\n",
      "    A gar of my good master's true. He may wejt,\n",
      "    Breathe himself, I risely: they mook'd him,\n",
      "    And sit on o'er your emports, if you were first of all\n",
      "    To see thou takes from faith, and fly but Scotlif her!\n",
      "    But look a My life too, O hoop on quarrels!\n",
      "    I will confess you to come for thee.\n",
      "  ROSALIND. Look should all undertake the old body of Cailia\n",
      "    Again, whence it is spurn.\n",
      "  NERISSA. You may think on's stame.\n",
      "  POLICE. Let us so low thyself, and vest so many loss.\n",
      "  SAFERS. His house?\n",
      "  SECOND SENATOR. They have itsly accompaty?\n",
      "  IAGO. I know wild love what he shall civil grief do curl'd not upon me.\n",
      "                                        \n"
     ]
    }
   ],
   "source": [
    "# Sample from trained model\n",
    "function generate(model,n)\n",
    "    function sample(y)\n",
    "        p,r=Array(exp.(Array(y).-logsumexp(y))),rand()\n",
    "        for j=1:length(p); (r -= p[j]) < 0 && return j; end\n",
    "    end\n",
    "    h,c = nothing,nothing\n",
    "    x = something(findfirst(isequal('\\n'), chars), 0)\n",
    "    for i=1:n\n",
    "        y,h,c = predict(model,[x],h,c)\n",
    "        x = sample(y)\n",
    "        print(chars[x])\n",
    "    end\n",
    "    println()\n",
    "end\n",
    "\n",
    "generate(model,1000)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.0.0",
   "language": "julia",
   "name": "julia-1.0"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.0.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
